{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# ONNX Conversion of the BERT Base Uncased Model\n",
        "This notebook is a companion to chapter 4 of the book \"Domain Specific LLMs in Action\" by Guglielmo Iozzia ([Manning Publications](https://www.manning.com/), 2024).  \n",
        "The code in this notebook introduces readers to the [ONNX](https://onnx.ai/) format and [ONNX Runtime](https://onnxruntime.ai/), using the [BERT Base Uncased](https://huggingface.co/google-bert/bert-base-uncased) model. It can be executed in the Colab free tier with hardware acceleration (GPU).  \n",
        "More details about the code can be found in the related chapter of the book."
      ],
      "metadata": {
        "id": "U86TB3HLMDGA"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Settings"
      ],
      "metadata": {
        "id": "RAEFQAnnMDCl"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Install the missing requirements in the Colab VM (ONNX, ONNX Runtime, and Hugging Face Datasets)."
      ],
      "metadata": {
        "id": "ucS0i0FcKhyn"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eoHyruS_HajJ"
      },
      "outputs": [],
      "source": [
        "!pip install onnx onnxruntime datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download the BERT Base Uncased model (and associated tokenizer) from the Hugging Face Hub."
      ],
      "metadata": {
        "id": "zAm-HA0GKr1Y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoModelForQuestionAnswering, BertTokenizer\n",
        "\n",
        "model_id = 'google-bert/bert-base-uncased'\n",
        "tokenizer = BertTokenizer.from_pretrained(model_id)\n",
        "model = AutoModelForQuestionAnswering.from_pretrained(model_id)\n",
        "model.eval()"
      ],
      "metadata": {
        "id": "3r2Gvok_Hhke"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download a subset of the SQuAD dataset from the Hugging Face Hub."
      ],
      "metadata": {
        "id": "_31UfS0KK3l9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "samples_count = 200\n",
        "squad = load_dataset(\"squad\", split=\"validation[:\"+ str(samples_count) +\"]\")"
      ],
      "metadata": {
        "id": "6CvicJ-vIlPO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Display the first sample of the subset."
      ],
      "metadata": {
        "id": "gFeTyX5zLyLR"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "squad[0]"
      ],
      "metadata": {
        "id": "wwoXesLFLkam"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Benchmark the original PyTorch model on the selected subset of the SQuAD validation set."
      ],
      "metadata": {
        "id": "nt6xHoqteIpa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "import torch\n",
        "\n",
        "max_seq_length = 128\n",
        "# Measure the latency.\n",
        "latency = []\n",
        "with torch.no_grad():\n",
        "    for i in range(samples_count):\n",
        "        inputs = tokenizer(squad[\"question\"][i], squad[\"context\"][i], return_tensors=\"pt\")\n",
        "        start = time.time()\n",
        "        outputs = model(**inputs)\n",
        "        latency.append(time.time() - start)\n",
        "print(\"PyTorch {} Average inference time = {} ms\".format('CPU', format(sum(latency) * 1000 / len(latency), '.2f')))"
      ],
      "metadata": {
        "id": "5VF5mlXILmLY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Convert the Model to ONNX"
      ],
      "metadata": {
        "id": "Pyrb2aV2ePcG"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Create the directory to host the converted model."
      ],
      "metadata": {
        "id": "t6nD0W--xlkS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "output_dir = os.path.join(\".\", \"onnx_models\")\n",
        "if not os.path.exists(output_dir):\n",
        "    os.makedirs(output_dir)\n",
        "export_model_path = os.path.join(output_dir, 'bert-base-uncased.onnx')"
      ],
      "metadata": {
        "id": "7--srbR3eSHd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Pick one sample from the validation subset to use as the example input for the export."
      ],
      "metadata": {
        "id": "3AT7lgNUxxBT"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "tokenized_inputs = tokenizer(squad[\"question\"][0], squad[\"context\"][0], return_tensors=\"pt\")\n",
        "inputs = {\n",
        "        'input_ids':  tokenized_inputs['input_ids'],\n",
        "        'input_mask': tokenized_inputs['attention_mask'],\n",
        "        'segment_ids': tokenized_inputs['token_type_ids']\n",
        "    }"
      ],
      "metadata": {
        "id": "VydnszbUY8AR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Export the model to ONNX."
      ],
      "metadata": {
        "id": "Q_ZZSDAuyNhb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with torch.no_grad():\n",
        "    # Axes 0 and 1 (batch size and sequence length) are exported as dynamic dimensions.\n",
        "    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}\n",
        "    torch.onnx.export(model,\n",
        "                      args=tuple(inputs.values()),\n",
        "                      f=export_model_path,\n",
        "                      opset_version=15,\n",
        "                      do_constant_folding=True,\n",
        "                      input_names=['input_ids',\n",
        "                                   'input_mask',\n",
        "                                   'segment_ids'],\n",
        "                      output_names=['start', 'end'],\n",
        "                      dynamic_axes={'input_ids': symbolic_names,\n",
        "                                    'input_mask': symbolic_names,\n",
        "                                    'segment_ids': symbolic_names,\n",
        "                                    'start': symbolic_names,\n",
        "                                    'end': symbolic_names})\n",
        "    print(\"Model exported at\", export_model_path)"
      ],
      "metadata": {
        "id": "E-U_HY3UeZ2-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Validate the exported model."
      ],
      "metadata": {
        "id": "pHuwFW6kyYZG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from onnx.checker import check_model\n",
        "\n",
        "check_model(export_model_path, full_check=True)"
      ],
      "metadata": {
        "id": "2Y0x56U9Vgem"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Benchmark the exported model (CPUExecutionProvider)."
      ],
      "metadata": {
        "id": "ZH8wg_bThmP9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import onnxruntime\n",
        "import numpy\n",
        "\n",
        "sess_options = onnxruntime.SessionOptions()\n",
        "\n",
        "# Save the graph optimized by ONNX Runtime to a separate file, so that the exported model isn't overwritten.\n",
        "sess_options.optimized_model_filepath = os.path.join(output_dir, \"bert-base-uncased.ort_cpu.onnx\")\n",
        "\n",
        "session = onnxruntime.InferenceSession(export_model_path, sess_options, providers=['CPUExecutionProvider'])"
      ],
      "metadata": {
        "id": "I4jd006vhvCZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "latency = []\n",
        "for i in range(samples_count):\n",
        "    full_inputs = tokenizer(squad[\"question\"][i], squad[\"context\"][i], return_tensors=\"np\")\n",
        "    ort_inputs = {\n",
        "        'input_ids':  full_inputs['input_ids'],\n",
        "        'input_mask': full_inputs['attention_mask'],\n",
        "        'segment_ids': full_inputs['token_type_ids']\n",
        "    }\n",
        "    start = time.time()\n",
        "    ort_outputs = session.run(None, ort_inputs)\n",
        "    latency.append(time.time() - start)\n",
        "print(\"OnnxRuntime cpu Average inference time = {} ms\".format(format(sum(latency) * 1000 / len(latency), '.2f')))"
      ],
      "metadata": {
        "id": "XvVh7mJVidgP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Verify the correctness of the exported model by comparing its outputs with those of the PyTorch model."
      ],
      "metadata": {
        "id": "SURG7c8I-Itx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"***** Verifying correctness *****\")\n",
        "sample_range = 2\n",
        "for i in range(sample_range):\n",
        "    print('PyTorch and ONNX Runtime output {} are close:'.format(i), numpy.allclose(ort_outputs[i], outputs[i].cpu(), rtol=1e-05, atol=1e-04))"
      ],
      "metadata": {
        "id": "ki_sJFNaqkUG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Model Optimization"
      ],
      "metadata": {
        "id": "z_KJL5_n-Q9K"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Optimize the exported model."
      ],
      "metadata": {
        "id": "oA1IxVvz-THI"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from onnxruntime.transformers import optimizer\n",
        "\n",
        "optimized_model_path = os.path.join(output_dir, 'bert-base-uncased.onnx_opt_cpu.onnx')\n",
        "optimized_model = optimizer.optimize_model(export_model_path, model_type='bert', num_heads=12, hidden_size=768)\n",
        "optimized_model.save_model_to_file(optimized_model_path)"
      ],
      "metadata": {
        "id": "YKL7qYcdyo_i"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Benchmark the optimized model (CPUExecutionProvider)."
      ],
      "metadata": {
        "id": "XFh-Jx-w-cTB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "sess_options_opt = onnxruntime.SessionOptions()\n",
        "\n",
        "# Load the model saved by the optimizer in the previous cell, not the original export.\n",
        "session_opt = onnxruntime.InferenceSession(optimized_model_path, sess_options_opt, providers=['CPUExecutionProvider'])"
      ],
      "metadata": {
        "id": "jLK0SgyDVyOt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "latency_opt = []\n",
        "for i in range(samples_count):\n",
        "    full_inputs = tokenizer(squad[\"question\"][i], squad[\"context\"][i], return_tensors=\"np\")\n",
        "    ort_inputs = {\n",
        "        'input_ids':  full_inputs['input_ids'],\n",
        "        'input_mask': full_inputs['attention_mask'],\n",
        "        'segment_ids': full_inputs['token_type_ids']\n",
        "    }\n",
        "    start = time.time()\n",
        "    ort_outputs = session_opt.run(None, ort_inputs)\n",
        "    latency_opt.append(time.time() - start)\n",
        "print(\"OnnxRuntime cpu (optimized model) Average inference time = {} ms\".format(format(sum(latency_opt) * 1000 / len(latency_opt), '.2f')))"
      ],
      "metadata": {
        "id": "hDInVOlXWDLJ"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}